{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construction and analysis of different sphere graphs\n",
    "* HEALPix sampling\n",
    "* Equiangular sampling, built with different methods (including a cylinder variant)\n",
    "* Polyhedron (icosahedron)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import hashlib\n",
    "import zipfile\n",
    "\n",
    "import numpy as np\n",
    "from scipy import sparse\n",
    "import matplotlib.pyplot as plt\n",
    "import healpy as hp\n",
    "\n",
    "plt.rcParams['figure.figsize'] = [10, 10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy import sparse\n",
    "\n",
    "from pygsp import utils\n",
    "from pygsp.graphs import Graph\n",
    "from pygsp.graphs import NNGraph\n",
    "from mpl_toolkits.mplot3d import Axes3D"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
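    "## HEALPix analysis\n",
    "Two graphs on the HEALPix sampling, built with PyGSP: one from the native pixel neighbours and one from the k nearest neighbours."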
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SphereHealpixNN(NNGraph):\n",
    "    \"\"\"k-nearest-neighbour graph on the HEALPix sampling of the sphere.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    nside : int\n",
    "        HEALPix resolution; the graph has 12 * nside**2 vertices.\n",
    "    n_neighbors : int or None\n",
    "        Number of nearest neighbours per vertex. None (or omitting it) selects\n",
    "        6 for nside == 1 and 8 otherwise.\n",
    "    nest : bool\n",
    "        Pixel ordering passed to healpy (NESTED if True, RING otherwise).\n",
    "    **kwargs\n",
    "        Forwarded to pygsp.graphs.NNGraph.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, nside, n_neighbors=None, nest=True, **kwargs):\n",
    "        import healpy as hp\n",
    "        self.nside = nside\n",
    "        npix = 12 * nside**2\n",
    "        indexes = np.arange(npix)\n",
    "        x, y, z = hp.pix2vec(nside, indexes, nest=nest)\n",
    "        coords = np.vstack([x, y, z]).transpose()\n",
    "        coords = np.asarray(coords, dtype=np.float32)\n",
    "        # Honour the caller's n_neighbors; only fall back to the default when none is given.\n",
    "        if n_neighbors is None:\n",
    "            n_neighbors = 6 if nside == 1 else 8\n",
    "\n",
    "        # Gaussian kernel width per resolution; unlisted nside values fall back to 0.001.\n",
    "        opt_std = {1: 0.5, 2: 0.15, 4: 0.05, 8: 0.0125, 16: 0.005, 32: 0.001}\n",
    "        sigma = opt_std.get(nside, 0.001)\n",
    "\n",
    "        plotting = {\n",
    "            'vertex_size': 80,\n",
    "            \"limits\": np.array([-1, 1, -1, 1, -1, 1])\n",
    "        }\n",
    "\n",
    "        super(SphereHealpixNN, self).__init__(coords, k=n_neighbors, center=False, rescale=False,\n",
    "                                              sigma=sigma, plotting=plotting, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SphereHealpix(Graph):\n",
    "    \"\"\"Graph on the HEALPix sampling whose edges link each pixel to its native neighbours.\n",
    "\n",
    "    Edge weights are a Gaussian kernel of the squared Euclidean distance between\n",
    "    neighbouring pixel centres; the kernel width is the mean squared distance.\n",
    "    A histogram of those distances is drawn with matplotlib as a side effect.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, nside, nest=True, **kwargs):\n",
    "        import healpy as hp\n",
    "        self.nside = nside\n",
    "        n_pix = 12 * nside**2\n",
    "        pix = np.arange(n_pix)\n",
    "\n",
    "        # Pixel-centre vectors, one row per pixel.\n",
    "        coords = np.asarray(np.stack(hp.pix2vec(nside, pix, nest=nest), axis=1), dtype=np.float32)\n",
    "\n",
    "        nbrs = hp.pixelfunc.get_all_neighbours(nside, pix, nest=nest)\n",
    "        cols = nbrs.T.reshape((n_pix * 8))\n",
    "        rows = np.repeat(pix, 8)\n",
    "        # Remove fake neighbours (some pixels have fewer than 8) and out-of-range indices.\n",
    "        valid = (cols >= 0) & (cols < n_pix)\n",
    "        rows, cols = rows[valid], cols[valid]\n",
    "\n",
    "        sq_dist = ((coords[rows] - coords[cols])**2).sum(axis=1)\n",
    "        plt.hist(sq_dist, 100)\n",
    "        weights = np.exp(-sq_dist / (2 * sq_dist.mean()))\n",
    "        W = sparse.csr_matrix((weights, (rows, cols)), shape=(n_pix, n_pix), dtype=np.float32)\n",
    "\n",
    "        super(SphereHealpix, self).__init__(W=W, coords=coords,\n",
    "                                            plotting={\"limits\": np.array([-1, 1, -1, 1, -1, 1])},\n",
    "                                            **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NSIDE = 8\n",
    "graphHp = SphereHealpix(NSIDE, nest=True)\n",
    "graphHpNN = SphereHealpixNN(NSIDE, n_neighbors=8, nest=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graphHp.plot()\n",
    "fig_spy, ax_spy = plt.subplots()\n",
    "ax_spy.spy(graphHp.W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graphHpNN.plot()\n",
    "fig_spy, ax_spy = plt.subplots()\n",
    "ax_spy.spy(graphHpNN.W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hp_nn_eigenvalues = graphHpNN.e[:25]\n",
    "plt.plot(hp_nn_eigenvalues, 'o')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graphHp.compute_laplacian(\"combinatorial\")\n",
    "# Spectral embedding: eigenvectors 1-3 (the first one is skipped) as 3D coordinates.\n",
    "spectral_coords = graphHp.U[:, 1:4]\n",
    "graphHp.set_coordinates(spectral_coords)\n",
    "graphHp.plot(vertex_size=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Equiangular graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SphereEquiangularNN(NNGraph):\n",
    "    def __init__(self, bw=64, sptype='DH', pole='zero', neighbors=8, weights=None, **kwargs):\n",
    "        \"\"\"Sphere with an equiangular sampling, connected as a k-nearest-neighbour graph.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        bw : int\n",
    "            bandwidth, size of grid (default = 64)\n",
    "        sptype : str\n",
    "            sampling type: 'DH' (Driscoll-Healy), 'SOFT', 'CC' (Clenshaw-Curtis),\n",
    "            'GL' (Gauss-Legendre) or 'OD' (optimal dimensionality) (default = 'DH')\n",
    "        pole : str\n",
    "            pole handling, only honoured for 'DH' (other types force 'zero'):\n",
    "            'all' keeps the colatitude-0 ring (2*bw coincident vertices),\n",
    "            'zero' uses 2*bw + 1 colatitudes and drops the one at 0,\n",
    "            'one' drops the colatitude-0 ring and adds a single north-pole vertex\n",
    "            (default = 'zero')\n",
    "        neighbors : int or 'all'\n",
    "            number of nearest neighbours per vertex; 'all' connects every vertex\n",
    "            to every other one (default = 8)\n",
    "        weights : None\n",
    "            unused, kept for backward compatibility\n",
    "        **kwargs\n",
    "            forwarded to pygsp.graphs.NNGraph\n",
    "\n",
    "        TODO: unique on neighbor\n",
    "        TODO: CC and GL are not equiangular and must be implemented in other ways\n",
    "        \"\"\"\n",
    "        if pole not in ['all', 'zero', 'one']:\n",
    "            raise ValueError('Unknown pole value:' + str(pole))\n",
    "        if sptype not in ['DH', 'SOFT', 'CC', 'GL', 'OD']:\n",
    "            raise ValueError('Unknown sampling type:' + str(sptype))\n",
    "        # Strings are compared with ==/!=; `is` tests object identity, not equality.\n",
    "        if sptype != 'DH' and pole != 'zero':\n",
    "            print('pole can be only zero with sampling type ' + sptype)\n",
    "            pole = 'zero'\n",
    "        self.bw = bw\n",
    "        self.sptype = sptype\n",
    "        self.pole = pole  # effective pole handling, after the override above\n",
    "\n",
    "        ## sampling and coordinates calculation\n",
    "        if sptype == 'DH':  # Driscoll-Healy\n",
    "            if pole == 'zero':\n",
    "                beta = np.arange(2 * bw + 1) * np.pi / (2. * bw + 1)\n",
    "            else:\n",
    "                beta = np.arange(2 * bw) * np.pi / (2. * bw)\n",
    "            alpha = np.arange(2 * bw) * np.pi / bw\n",
    "            if pole != 'all':\n",
    "                beta = beta[1:]  # drop the colatitude-0 ring\n",
    "        elif sptype == 'SOFT':  # SO(3) Fourier Transform optimal\n",
    "            beta = np.pi * (2 * np.arange(2 * bw) + 1) / (4. * bw)\n",
    "            alpha = np.arange(2 * bw) * np.pi / bw\n",
    "        elif sptype == 'CC':  # Clenshaw-Curtis\n",
    "            beta = np.linspace(0, np.pi, 2 * bw + 1)\n",
    "            alpha = np.linspace(0, 2 * np.pi, 2 * bw + 2, endpoint=False)\n",
    "        elif sptype == 'GL':  # Gauss-Legendre\n",
    "            from numpy.polynomial.legendre import leggauss\n",
    "            x, _ = leggauss(bw + 1)  # TODO: leggauss docs state that this may not be only stable for orders > 100\n",
    "            beta = np.arccos(x)\n",
    "            alpha = np.arange(2 * bw + 2) * np.pi / (bw + 1)\n",
    "\n",
    "        if sptype == 'OD':  # Optimal Dimensionality\n",
    "            # Colatitudes alternate between hemispheres: even k -> k, odd k -> 4*bw - 1 - k\n",
    "            # (in units of pi / (4*bw - 1)). Note (-1)**m, not -1**m, which is always -1.\n",
    "            k = np.arange(2 * bw + 1)\n",
    "            beta = np.pi * ((k % 2) * (4 * bw - 1) + k * (-1) ** (k % 2)) / (4 * bw - 1)\n",
    "            theta, phi = np.zeros(4 * bw**2), np.zeros(4 * bw**2)\n",
    "            index = 0\n",
    "            for i in range(2 * bw):\n",
    "                # ring i holds 2*i + 1 equally spaced longitudes\n",
    "                alpha = 2 * np.pi * np.arange(2 * i + 1) / (2 * i + 1)\n",
    "                end = len(alpha)\n",
    "                theta[index:index + end], phi[index:index + end] = np.repeat(beta[i], end), alpha\n",
    "                index += end\n",
    "        else:\n",
    "            theta, phi = np.meshgrid(beta, alpha, indexing='ij')\n",
    "        ct = np.cos(theta).flatten()\n",
    "        st = np.sin(theta).flatten()\n",
    "        cp = np.cos(phi).flatten()\n",
    "        sp = np.sin(phi).flatten()\n",
    "        coords = np.vstack([st * cp, st * sp, ct]).T\n",
    "        if pole == 'one':\n",
    "            coords = np.vstack([[0., 0., 1.], coords])\n",
    "        coords = np.asarray(coords, dtype=np.float32)\n",
    "        self.npix = len(coords)\n",
    "        if neighbors == 'all':\n",
    "            neighbors = self.npix - 1\n",
    "\n",
    "        plotting = {\"limits\": np.array([-1, 1, -1, 1, -1, 1])}\n",
    "        super(SphereEquiangularNN, self).__init__(coords, k=neighbors, center=False, rescale=False,\n",
    "                                                  plotting=plotting, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "equiNN = SphereEquiangularNN(bw=4, sptype='OD', pole='zero', neighbors=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "equiNN.plot()\n",
    "fig_spy, ax_spy = plt.subplots()\n",
    "ax_spy.spy(equiNN.W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "equi_nn_eigenvalues = equiNN.e[:16]\n",
    "plt.plot(equi_nn_eigenvalues, 'o')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "equiNN.compute_laplacian(\"normalized\")\n",
    "equiNN.compute_fourier_basis(recompute=True)\n",
    "# Spectral embedding: eigenvectors 1-3 (the first one is skipped) as 3D coordinates.\n",
    "embedding = equiNN.U[:, 1:4]\n",
    "equiNN.set_coordinates(embedding)\n",
    "equiNN.plot(vertex_size=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import spatial\n",
    "class SphereEquiangular(Graph):\n",
    "    def __init__(self, bw=64, sptype='DH', pole='disconnected', neighbors=8, w_mat=None, dist='geodesic',\n",
    "                 affine=0., geometry='sphere', delta='one', **kwargs):\n",
    "        \"\"\"Sphere (or cylinder) with an equiangular sampling and a grid-neighbour graph.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        bw : int\n",
    "            bandwidth, size of grid (default = 64)\n",
    "        sptype : str\n",
    "            sampling type: 'DH', 'SOFT', 'CC' or 'GL' (default = 'DH')\n",
    "        pole : str\n",
    "            'disconnected' or 'connected' (default = 'disconnected'); 'connected'\n",
    "            links each pixel of the first and last rings to the pixel half a turn\n",
    "            away on the same ring\n",
    "        neighbors : int or str\n",
    "            4 or 8 grid neighbours, 'all' (every other pixel), or 'full' (dense graph\n",
    "            weighted by inverse squared Euclidean distance) (default = 8)\n",
    "        w_mat : None, 'one' or number\n",
    "            'one' gives unit weights; otherwise Gaussian kernel weights, whose kernel\n",
    "            width is w_mat if it is a number and the median distance if not\n",
    "        dist : str\n",
    "            'geodesic' for great-circle angles, anything else for squared\n",
    "            Euclidean distances (default = 'geodesic')\n",
    "        affine : float\n",
    "            affinely remaps distances above 0.95 * median distance (default = 0.)\n",
    "        geometry : str\n",
    "            'sphere' or 'cylinder' (default = 'sphere')\n",
    "        delta : str\n",
    "            'one' takes the adjacent pixel as east/west neighbour; any other value\n",
    "            uses a step that grows away from the middle ring (default = 'one')\n",
    "        **kwargs\n",
    "            forwarded to pygsp.graphs.Graph\n",
    "\n",
    "        TODO: unique on neighbor\n",
    "        TODO: CC and GL are not equiangular and must be implemented in other ways\n",
    "        \"\"\"\n",
    "        self.bw = bw\n",
    "        self.sptype = sptype\n",
    "        self.pole = pole\n",
    "        if pole not in ['disconnected', 'connected']:\n",
    "            raise ValueError('Unknown pole value:' + str(pole))\n",
    "        if sptype not in ['DH', 'SOFT', 'CC', 'GL']:\n",
    "            raise ValueError('Unknown sampling type:' + str(sptype))\n",
    "        if neighbors not in [4, 8, 'all', 'full']:\n",
    "            # str() is required: neighbors is usually an int here\n",
    "            raise ValueError('impossible numbers of neighbors:' + str(neighbors))\n",
    "        if geometry not in ['sphere', 'cylinder']:\n",
    "            raise ValueError('Unknown geometry:' + str(geometry))\n",
    "\n",
    "        ## sampling and coordinates calculation\n",
    "        if sptype == 'DH':  # Driscoll-Healy\n",
    "            beta = np.arange(2 * bw) * np.pi / (2. * bw)\n",
    "            alpha = np.arange(2 * bw) * np.pi / bw\n",
    "        elif sptype == 'SOFT':  # SO(3) Fourier Transform optimal\n",
    "            beta = np.pi * (2 * np.arange(2 * bw) + 1) / (4. * bw)\n",
    "            alpha = np.arange(2 * bw) * np.pi / bw\n",
    "        elif sptype == 'CC':  # Clenshaw-Curtis\n",
    "            beta = np.linspace(0, np.pi, 2 * bw + 1)\n",
    "            alpha = np.linspace(0, 2 * np.pi, 2 * bw + 2, endpoint=False)\n",
    "        else:  # 'GL', Gauss-Legendre\n",
    "            from numpy.polynomial.legendre import leggauss\n",
    "            x, _ = leggauss(bw + 1)  # TODO: leggauss docs state that this may not be only stable for orders > 100\n",
    "            beta = np.arccos(x)\n",
    "            alpha = np.arange(2 * bw + 2) * np.pi / (bw + 1)\n",
    "        theta, phi = np.meshgrid(beta, alpha, indexing='ij')\n",
    "        # Pixel x lies on ring x // self.lon, at position x % self.lon within the ring.\n",
    "        self.lat, self.lon = theta.shape\n",
    "        if geometry == 'sphere':\n",
    "            ct = np.cos(theta).flatten()\n",
    "            st = np.sin(theta).flatten()\n",
    "        else:  # 'cylinder': unit radius, height proportional to the colatitude\n",
    "            ct = theta.flatten() * 2 * bw / np.pi\n",
    "            st = 1\n",
    "        cp = np.cos(phi).flatten()\n",
    "        sp = np.sin(phi).flatten()\n",
    "        coords = np.vstack([st * cp, st * sp, ct]).T\n",
    "        coords = np.asarray(coords, dtype=np.float32)\n",
    "        self.npix = len(coords)\n",
    "        lat, lon, npix = self.lat, self.lon, self.npix\n",
    "        plotting = {\"limits\": np.array([-1, 1, -1, 1, -1, 1])}\n",
    "\n",
    "        if neighbors == 'full':\n",
    "            self.coords = coords\n",
    "            distances = spatial.distance.cdist(coords, coords)**2\n",
    "            # An infinite self-distance yields a zero diagonal without dividing by zero.\n",
    "            np.fill_diagonal(distances, np.inf)\n",
    "            weights = 1 / distances\n",
    "            W = sparse.csr_matrix(weights, dtype=np.float32)\n",
    "            super(SphereEquiangular, self).__init__(W=W, coords=coords,\n",
    "                                                    plotting=plotting, **kwargs)\n",
    "            return\n",
    "\n",
    "        ## neighbors and weight matrix calculation\n",
    "        if delta == 'one':\n",
    "            def step(x):\n",
    "                return 1\n",
    "        else:\n",
    "            def step(x):\n",
    "                # east/west step grows with the distance from the middle ring\n",
    "                ring_offset = abs(x // lon - lat // 2)\n",
    "                return int(1 + ring_offset * (lon - 1) / lat)\n",
    "\n",
    "        def south(x):\n",
    "            if x >= npix - lon:  # last ring\n",
    "                if pole == 'connected':\n",
    "                    return (x + lon // 2) % lon + npix - lon\n",
    "                return -1\n",
    "            return x + lon\n",
    "\n",
    "        def north(x):\n",
    "            if x < lon:  # first ring\n",
    "                if pole == 'connected':\n",
    "                    return (x + lon // 2) % lon\n",
    "                return -1\n",
    "            return x - lon\n",
    "\n",
    "        def west(x):\n",
    "            d = step(x)\n",
    "            target = x - d\n",
    "            if x % lon < d:  # wrap around within the same ring\n",
    "                target += lon\n",
    "            assert x // lon == target // lon, (x, d, target)\n",
    "            return target\n",
    "\n",
    "        def east(x):\n",
    "            d = step(x)\n",
    "            target = x + d\n",
    "            if x % lon >= lon - d:  # wrap around within the same ring\n",
    "                target -= lon\n",
    "            assert x // lon == target // lon, (x, d, target)\n",
    "            return target\n",
    "\n",
    "        col_index = []\n",
    "        for ind in range(npix):\n",
    "            if neighbors == 8:\n",
    "                neighbor = [south(west(ind)), west(ind), north(west(ind)), north(ind),\n",
    "                            north(east(ind)), east(ind), south(east(ind)), south(ind)]\n",
    "            elif neighbors == 4:\n",
    "                neighbor = [west(ind), north(ind), east(ind), south(ind)]\n",
    "            else:  # 'all'\n",
    "                neighbor = [j for j in range(npix) if j != ind]\n",
    "            col_index += neighbor\n",
    "        col_index = np.asarray(col_index)\n",
    "        per_row = npix - 1 if neighbors == 'all' else neighbors\n",
    "        row_index = np.repeat(np.arange(npix), per_row)\n",
    "\n",
    "        # Drop missing neighbours (-1 at disconnected poles).\n",
    "        keep = (col_index >= 0) & (col_index < npix)\n",
    "        col_index = col_index[keep]\n",
    "        row_index = row_index[keep]\n",
    "\n",
    "        if w_mat != 'one':\n",
    "            if dist == 'geodesic':\n",
    "                distances = np.zeros(len(row_index))\n",
    "                for i, (pos1, pos2) in enumerate(zip(coords[row_index], coords[col_index])):\n",
    "                    d1 = hp.rotator.vec2dir(pos1.T, lonlat=False).T\n",
    "                    d2 = hp.rotator.vec2dir(pos2.T, lonlat=False).T\n",
    "                    distances[i] = hp.rotator.angdist(d1, d2, lonlat=False)\n",
    "            else:\n",
    "                distances = np.sum((coords[row_index] - coords[col_index])**2, axis=1)\n",
    "\n",
    "            if delta != 'one':\n",
    "                distances[::2] = distances[::2] * 4\n",
    "            # Compute similarities / edge weights.\n",
    "            kernel_width = np.median(distances)\n",
    "            threshold = kernel_width * 0.95\n",
    "            # Affine remap of the distances above 0.95 * median; affine=0 leaves them unchanged.\n",
    "            slope = (kernel_width * (0.95 - affine)) / threshold\n",
    "            far = distances > threshold\n",
    "            distances[far] = affine * kernel_width + distances[far] * slope\n",
    "            if isinstance(w_mat, (int, float)):\n",
    "                kernel_width = w_mat  # a numeric w_mat overrides the median kernel width\n",
    "            weights = np.exp(-distances / (2 * kernel_width))\n",
    "            plt.hist(distances, 100)\n",
    "        else:\n",
    "            weights = np.ones((len(row_index),))\n",
    "            # NOTE(review): this mask is empty unless self.lat <= 2; confirm the intended condition.\n",
    "            near_pole = (col_index < 2) & (col_index > lat - 2)\n",
    "            weights[near_pole] *= 0.1\n",
    "\n",
    "        W = sparse.csr_matrix(\n",
    "            (weights, (row_index, col_index)), shape=(npix, npix), dtype=np.float32)\n",
    "        super(SphereEquiangular, self).__init__(W=W, coords=coords,\n",
    "                                                plotting=plotting, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = SphereEquiangular(bw=8, sptype='SOFT', pole='connected', neighbors='full', w_mat=None,\n",
    "                      dist='geodesic', affine=0.0, delta='acc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g_eigenvalues = g.e[:25]\n",
    "plt.plot(g_eigenvalues, 'o')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g.compute_laplacian(\"combinatorial\")\n",
    "plt.rcParams['figure.figsize'] = [10, 10]\n",
    "fig = plt.figure()\n",
    "ax = plt.subplot(111, projection='3d')\n",
    "# Spectral embedding: eigenvectors 1-3 (the first one is skipped) as 3D coordinates.\n",
    "g.set_coordinates(g.U[:, 1:4])\n",
    "g.plot(vertex_size=10, ax=ax)\n",
    "\n",
    "# Plot white points at the corners of a cube enclosing the embedding so that\n",
    "# the 3D axes get an equal aspect ratio.\n",
    "X, Y, Z = g.U[:, 1:4].T\n",
    "max_range = max(X.max() - X.min(), Y.max() - Y.min(), Z.max() - Z.min())\n",
    "cube = np.mgrid[-1:2:2, -1:2:2, -1:2:2].reshape(3, -1)\n",
    "Xb, Yb, Zb = [0.5 * max_range * corner + 0.5 * (v.max() + v.min()) for corner, v in zip(cube, (X, Y, Z))]\n",
    "for xb, yb, zb in zip(Xb, Yb, Zb):\n",
    "    ax.plot([xb], [yb], [zb], 'w')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graphEqui = SphereEquiangular(bw=8, sptype='SOFT', pole='connected', neighbors=4, w_mat=None,\n",
    "                              dist='geodesic', affine=0.0, delta='acc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graphEqui = SphereEquiangular(bw=8, sptype='SOFT', pole='connected', neighbors=4, w_mat=None,\n",
    "                              dist='geodesic', affine=0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "graphEqui.plot()\n",
    "fig_spy, ax_spy = plt.subplots()\n",
    "ax_spy.spy(graphEqui.W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "equi_eigenvalues = graphEqui.e[:25]\n",
    "plt.plot(equi_eigenvalues, 'o')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "graphEqui.compute_laplacian(\"combinatorial\")\n",
    "plt.rcParams['figure.figsize'] = [10, 10]\n",
    "fig = plt.figure()\n",
    "ax = plt.subplot(111, projection='3d')\n",
    "# Spectral embedding: eigenvectors 1-3 (the first one is skipped) as 3D coordinates.\n",
    "graphEqui.set_coordinates(graphEqui.U[:, 1:4])\n",
    "graphEqui.plot(vertex_size=10, ax=ax)\n",
    "\n",
    "# Plot white points at the corners of a cube enclosing the embedding so that\n",
    "# the 3D axes get an equal aspect ratio.\n",
    "X, Y, Z = graphEqui.U[:, 1:4].T\n",
    "max_range = max(X.max() - X.min(), Y.max() - Y.min(), Z.max() - Z.min())\n",
    "cube = np.mgrid[-1:2:2, -1:2:2, -1:2:2].reshape(3, -1)\n",
    "Xb, Yb, Zb = [0.5 * max_range * corner + 0.5 * (v.max() + v.min()) for corner, v in zip(cube, (X, Y, Z))]\n",
    "for xb, yb, zb in zip(Xb, Yb, Zb):\n",
    "    ax.plot([xb], [yb], [zb], 'w')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "graphEqui.compute_laplacian(\"combinatorial\")\n",
    "plt.rcParams['figure.figsize'] = [10, 10]\n",
    "fig = plt.figure()\n",
    "ax = plt.subplot(111)\n",
    "ax.axis('equal')\n",
    "# 2D spectral embedding: eigenvectors 1 and 2 as planar coordinates.\n",
    "graphEqui.set_coordinates(graphEqui.U[:, 1:3])\n",
    "graphEqui.plot(vertex_size=10, ax=ax)\n",
    "ax.set_title('')\n",
    "# savefig's keyword is bbox_inches; also make sure the output folder exists.\n",
    "os.makedirs('./figures', exist_ok=True)\n",
    "fig.savefig('./figures/equi_embeded_1sh.png', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graphEqui.set_coordinates(graphEqui.U[:, 1:4])\n",
    "plt.rcParams['figure.figsize'] = [15, 5]\n",
    "fig = plt.figure()\n",
    "ax1 = fig.add_subplot(131, projection='3d')\n",
    "ax2 = fig.add_subplot(132, projection='3d')\n",
    "ax3 = fig.add_subplot(133, projection='3d')\n",
    "for ax in (ax1, ax2, ax3):\n",
    "    ax.axis('equal')\n",
    "graphEqui.plot_signal(graphEqui.U[:, 1], ax=ax1, title='m=0', edges=False)\n",
    "graphEqui.plot_signal(graphEqui.U[:, 2], ax=ax2, title='m=1', edges=False)\n",
    "graphEqui.plot_signal(graphEqui.U[:, 3], ax=ax3, title='m=-1', edges=False)\n",
    "# savefig's keyword is bbox_inches; also make sure the output folder exists.\n",
    "os.makedirs('./figures', exist_ok=True)\n",
    "fig.savefig('./figures/equi_sh.png', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eigenvectors are the columns of U; U[0] would be the first row (one vertex across all eigenvectors).\n",
    "graphEqui.plot_signal(graphEqui.U[:, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_2d = graphEqui.U[:, 2:4]\n",
    "graphEqui.set_coordinates(embedding_2d)\n",
    "graphEqui.plot(vertex_size=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_2d = graphEqui.U[:, 1:3]\n",
    "graphEqui.set_coordinates(embedding_2d)\n",
    "graphEqui.plot(vertex_size=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def geodesic_distance(pos1, pos2):\n",
    "    \"\"\"Great-circle distance, in radians, between two points on the unit sphere.\n",
    "\n",
    "    Uses the Vincenty formula with arctan2, so the result is correct over the\n",
    "    whole range [0, pi] (plain arctan of the ratio fails beyond pi/2).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    pos1, pos2 : array_like of length 2\n",
    "        (latitude, longitude) of each point, in radians.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Angular distance between the two points, in [0, pi].\n",
    "    \"\"\"\n",
    "    lat1, lon1 = pos1\n",
    "    lat2, lon2 = pos2\n",
    "    delta_lon = np.abs(lon1 - lon2)\n",
    "    sin_lat1, cos_lat1 = np.sin(lat1), np.cos(lat1)\n",
    "    sin_lat2, cos_lat2 = np.sin(lat2), np.cos(lat2)\n",
    "    numerator = np.hypot(cos_lat2 * np.sin(delta_lon),\n",
    "                         cos_lat1 * sin_lat2 - sin_lat1 * cos_lat2 * np.cos(delta_lon))\n",
    "    denominator = sin_lat1 * sin_lat2 + cos_lat1 * cos_lat2 * np.cos(delta_lon)\n",
    "    return np.arctan2(numerator, denominator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cylinder graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graphCyl = SphereEquiangular(bw=10, sptype='SOFT', pole='disconnected', neighbors=4,\n",
    "                             w_mat='one', dist='euclidean', geometry='cylinder', affine=0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graphCyl.plot()\n",
    "fig_spy, ax_spy = plt.subplots()\n",
    "ax_spy.spy(graphCyl.W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "cyl_eigenvalues = graphCyl.e[:25]\n",
    "plt.plot(cyl_eigenvalues, 'o')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "graphCyl.compute_laplacian(\"normalized\")\n",
    "# Spectral embedding: eigenvectors 2-4 as 3D coordinates.\n",
    "graphCyl.set_coordinates(graphCyl.U[:, 2:5])\n",
    "graphCyl.plot(vertex_size=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Icosahedron graph\n",
    "See Jiang et al.'s code ([GitHub](https://github.com/maxjiang93/ugscnn/blob/master/meshcnn/mesh.py)) for a more precise implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unique_rows(data, digits=None):\n",
    "    \"\"\"\n",
    "    Return the indices of the unique rows of data, keeping the first\n",
    "    occurrence of every duplicated row:\n",
    "    [[1,2], [3,4], [1,2]] gives unique == [0, 1].\n",
    "    Parameters\n",
    "    ---------\n",
    "    data: (n,m) set of floating point data\n",
    "    digits: how many digits to consider for the purposes of uniqueness\n",
    "    Returns\n",
    "    --------\n",
    "    unique:  (j) array, index in data which is a unique row\n",
    "    inverse: (n) length array to reconstruct original\n",
    "                 example: data[unique][inverse] == data\n",
    "    \"\"\"\n",
    "    row_keys = hashable_rows(data, digits=digits)\n",
    "    _, first_index, inverse = np.unique(row_keys, return_index=True, return_inverse=True)\n",
    "    return first_index, inverse\n",
    "\n",
    "def hashable_rows(data, digits=None):\n",
    "    \"\"\"\n",
    "    Turn every row of data into one value that can be sorted or hashed.\n",
    "    Rows are first converted to integers at the precision given by digits;\n",
    "    narrow, small-valued rows are then packed into a single int64 and any\n",
    "    other row is viewed as one opaque fixed-size np.void record.\n",
    "    Parameters\n",
    "    ---------\n",
    "    data:    (n,m) input array\n",
    "    digits:  how many digits to add to hash, if data is floating point\n",
    "             If none, TOL_MERGE will be turned into a digit count and used.\n",
    "    Returns\n",
    "    ---------\n",
    "    hashable:  (n) length array of custom data which can be sorted\n",
    "                or used as hash keys\n",
    "    \"\"\"\n",
    "    # nothing to hash\n",
    "    if len(data) == 0:\n",
    "        return np.array([])\n",
    "\n",
    "    # integers at the precision we care about\n",
    "    as_int = float_to_int(data, digits=digits)\n",
    "\n",
    "    # already one integer per row\n",
    "    n_dims = as_int.ndim\n",
    "    if n_dims == 1:\n",
    "        return as_int\n",
    "\n",
    "    # Bit-pack rows of at most 4 columns into one 64-bit integer each, when every\n",
    "    # value fits in its share of the bits (faster than the void-dtype fallback).\n",
    "    if n_dims == 2 and as_int.shape[1] <= 4:\n",
    "        bits_per_column = int(np.floor(64 / as_int.shape[1]))\n",
    "        if np.abs(as_int).max() < 2**(bits_per_column - 1):\n",
    "            packed = np.zeros(len(as_int), dtype=np.int64)\n",
    "            # columns must be int64, otherwise the bit offsets do not work\n",
    "            columns = as_int.astype(np.int64).T\n",
    "            for slot, column in enumerate(columns):\n",
    "                np.bitwise_xor(packed, column << (slot * bits_per_column), out=packed)\n",
    "            return packed\n",
    "\n",
    "    # Fallback: view each contiguous row as a single fixed-size np.void record.\n",
    "    row_dtype = np.dtype((np.void, as_int.dtype.itemsize * as_int.shape[1]))\n",
    "    return np.ascontiguousarray(as_int).view(row_dtype).reshape(-1)\n",
    "\n",
    "def float_to_int(data, digits=None, dtype=np.int32):\n",
    "    \"\"\"\n",
    "    Given a numpy array of float/bool/int, return as integers.\n",
    "    Parameters\n",
    "    -------------\n",
    "    data:   (n, d) float, int, or bool data\n",
    "    digits: float/int precision for float conversion\n",
    "    dtype:  numpy dtype for result\n",
    "    Returns\n",
    "    -------------\n",
    "    as_int: data, as integers\n",
    "    \"\"\"\n",
    "    # convert to any numpy array\n",
    "    data = np.asanyarray(data)\n",
    "\n",
    "    # if data is already an integer or boolean we're done\n",
    "    # if the data is empty we are also done\n",
    "    if data.dtype.kind in 'ib' or data.size == 0:\n",
    "        return data.astype(dtype)\n",
    "\n",
    "    # populate digits from kwargs\n",
    "    if digits is None:\n",
    "        digits = decimal_to_digits(1e-8)\n",
    "    elif isinstance(digits, float) or isinstance(digits, np.float):\n",
    "        digits = decimal_to_digits(digits)\n",
    "    elif not (isinstance(digits, int) or isinstance(digits, np.integer)):\n",
    "        log.warn('Digits were passed as %s!', digits.__class__.__name__)\n",
    "        raise ValueError('Digits must be None, int, or float!')\n",
    "\n",
    "    # data is float so convert to large integers\n",
    "    data_max = np.abs(data).max() * 10**digits\n",
    "    # ignore passed dtype if we have something large\n",
    "    dtype = [np.int32, np.int64][int(data_max > 2**31)]\n",
    "    # multiply by requested power of ten\n",
    "    # then subtract small epsilon to avoid \"go either way\" rounding\n",
    "    # then do the rounding and convert to integer\n",
    "    as_int = np.round((data * 10 ** digits) - 1e-6).astype(dtype)\n",
    "\n",
    "    return as_int\n",
    "\n",
    "\n",
    "def decimal_to_digits(decimal, min_digits=None):\n",
    "    \"\"\"\n",
    "    Return the number of digits to the first nonzero decimal.\n",
    "    Parameters\n",
    "    -----------\n",
    "    decimal:    float\n",
    "    min_digits: int, minumum number of digits to return\n",
    "    Returns\n",
    "    -----------\n",
    "    digits: int, number of digits to the first nonzero decimal\n",
    "    \"\"\"\n",
    "    digits = abs(int(np.log10(decimal)))\n",
    "    if min_digits is not None:\n",
    "        digits = np.clip(digits, min_digits, 20)\n",
    "    return digits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SphereIcosahedron(NNGraph):\n",
    "    def __init__(self, level, sampling='vertex', **kwargs):\n",
    "        from collections import deque\n",
    "        ## sampling in ['vertex', 'face']\n",
    "        self.intp = None\n",
    "        PHI = (1 + np.sqrt(5))/2\n",
    "        radius = np.sqrt(PHI**2+1)\n",
    "        coords = np.zeros((12,3))\n",
    "        pointUpFor = deque([0, 1, PHI])\n",
    "        pointUpBack = deque([0, -1, PHI])\n",
    "        pointDownFor = deque([0, 1, -PHI])\n",
    "        pointDownBack = deque([0, -1, -PHI])\n",
    "        for i in range(3):\n",
    "            coords[4*i] = pointUpFor\n",
    "            coords[4*i+1] = pointUpBack\n",
    "            coords[4*i+2] = pointDownFor\n",
    "            coords[4*i+3] = pointDownBack\n",
    "            pointUpFor.rotate()\n",
    "            pointUpBack.rotate()\n",
    "            pointDownFor.rotate()\n",
    "            pointDownBack.rotate()\n",
    "        coords = coords/radius\n",
    "        faces = [1, 2, 7, 1, 7, 10, 1, 10, 9, 1, 9, 5, 1, 5, 2, 2, 7, 12, 12, 7, 8, 7, 8, 10, 8, 10, 3, 10, 3, 9, 3, 9, 6, 9, 6, 5, 6, 5, 11, 5, 11, 2, 11, 2, 12, 4, 11, 12, 4, 12, 8, 4, 8, 3, 4, 3, 6, 4, 6, 11]\n",
    "        self.faces = np.reshape(faces, (20,3))-1\n",
    "        self.level = level\n",
    "        self.coords = coords\n",
    "        \n",
    "        self.coords = self._upward(coords, self.faces)\n",
    "        ## rotate icosahedron?\n",
    "        for i in range(level):\n",
    "            self.divide()\n",
    "            self.normalize()\n",
    "        \n",
    "        if sampling=='face':\n",
    "            self.coords = self.coords[self.faces].mean(axis=1)\n",
    "            \n",
    "        self.lat, self.long = self.xyz2latlong()\n",
    "#         theta = [0] + 5*[np.pi/2-np.arctan(0.5)] + 5*[np.pi/2+np.arctan(0.5)] + [np.pi]\n",
    "#         phi = [0] + np.linspace(0, 2*np.pi, 5, endpoint=False).tolist() +  (np.linspace(0, 2*np.pi, 5, endpoint=False)+(np.pi/5)).tolist() + [0]\n",
    "        \n",
    "        self.npix = len(self.coords)\n",
    "        self.nf = 20 * 4**self.level\n",
    "        self.ne = 30 * 4**self.level\n",
    "        self.nv = self.ne - self.nf + 2\n",
    "        self.nv_prev = int((self.ne / 4) - (self.nf / 4) + 2)\n",
    "        self.nv_next = int((self.ne * 4) - (self.nf * 4) + 2)\n",
    "        #W = np.ones((self.npix, self.npix))\n",
    "        \n",
    "        neighbours = 3 if 'face' in sampling else (5 if level == 0 else 6)\n",
    "        super(SphereIcosahedron, self).__init__(self.coords, k=neighbours, **kwargs)\n",
    "        \n",
    "    def divide(self):\n",
    "        \"\"\"\n",
    "        Subdivide a mesh into smaller triangles.\n",
    "        \"\"\"\n",
    "        faces = self.faces\n",
    "        vertices = self.coords\n",
    "        face_index = np.arange(len(faces))\n",
    "\n",
    "        # the (c,3) int set of vertex indices\n",
    "        faces = faces[face_index]\n",
    "        # the (c, 3, 3) float set of points in the triangles\n",
    "        triangles = vertices[faces]\n",
    "        # the 3 midpoints of each triangle edge vstacked to a (3*c, 3) float\n",
    "        src_idx = np.vstack([faces[:, g] for g in [[0, 1], [1, 2], [2, 0]]])\n",
    "        mid = np.vstack([triangles[:, g, :].mean(axis=1) for g in [[0, 1],\n",
    "                                                                   [1, 2],\n",
    "                                                                   [2, 0]]])\n",
    "        mid_idx = (np.arange(len(face_index) * 3)).reshape((3, -1)).T\n",
    "        # for adjacent faces we are going to be generating the same midpoint\n",
    "        # twice, so we handle it here by finding the unique vertices\n",
    "        unique, inverse = unique_rows(mid)\n",
    "\n",
    "        mid = mid[unique]\n",
    "        src_idx = src_idx[unique]\n",
    "        mid_idx = inverse[mid_idx] + len(vertices)\n",
    "        # the new faces, with correct winding\n",
    "        f = np.column_stack([faces[:, 0], mid_idx[:, 0], mid_idx[:, 2],\n",
    "                             mid_idx[:, 0], faces[:, 1], mid_idx[:, 1],\n",
    "                             mid_idx[:, 2], mid_idx[:, 1], faces[:, 2],\n",
    "                             mid_idx[:, 0], mid_idx[:, 1], mid_idx[:, 2], ]).reshape((-1, 3))\n",
    "        # add the 3 new faces per old face\n",
    "        new_faces = np.vstack((faces, f[len(face_index):]))\n",
    "        # replace the old face with a smaller face\n",
    "        new_faces[face_index] = f[:len(face_index)]\n",
    "\n",
    "        new_vertices = np.vstack((vertices, mid))\n",
    "        # source ids\n",
    "        nv = vertices.shape[0]\n",
    "        identity_map = np.stack((np.arange(nv), np.arange(nv)), axis=1)\n",
    "        src_id = np.concatenate((identity_map, src_idx), axis=0)\n",
    "\n",
    "        self.coords = new_vertices\n",
    "        self.faces = new_faces\n",
    "        self.intp = src_id\n",
    "        \n",
    "    def normalize(self, radius=1):\n",
    "        '''\n",
    "        Reproject to spherical surface\n",
    "        '''\n",
    "        vectors = self.coords\n",
    "        scalar = (vectors ** 2).sum(axis=1)**.5\n",
    "        unit = vectors / scalar.reshape((-1, 1))\n",
    "        offset = radius - scalar\n",
    "        self.coords += unit * offset.reshape((-1, 1))\n",
    "        \n",
    "    def xyz2latlong(self):\n",
    "        x, y, z = self.coords[:, 0], self.coords[:, 1], self.coords[:, 2]\n",
    "        long = np.arctan2(y, x)\n",
    "        xy2 = x**2 + y**2\n",
    "        lat = np.arctan2(z, np.sqrt(xy2))\n",
    "        return lat, long\n",
    "    \n",
    "    def _upward(self, V_ico, F_ico, ind=11):\n",
    "        V0 = V_ico[ind]\n",
    "        Z0 = np.array([0, 0, 1])\n",
    "        k = np.cross(V0, Z0)\n",
    "        ct = np.dot(V0, Z0)\n",
    "        st = -np.linalg.norm(k)\n",
    "        R = self._rot_matrix(k, ct, st)\n",
    "        V_ico = V_ico.dot(R)\n",
    "        # rotate a neighbor to align with (+y)\n",
    "        ni = self._find_neighbor(F_ico, ind)[0]\n",
    "        vec = V_ico[ni].copy()\n",
    "        vec[2] = 0\n",
    "        vec = vec/np.linalg.norm(vec)\n",
    "        y_ = np.eye(3)[1]\n",
    "\n",
    "        k = np.eye(3)[2]\n",
    "        crs = np.cross(vec, y_)\n",
    "        ct = -np.dot(vec, y_)\n",
    "        st = -np.sign(crs[-1])*np.linalg.norm(crs)\n",
    "        R2 = self._rot_matrix(k, ct, st)\n",
    "        V_ico = V_ico.dot(R2)\n",
    "        return V_ico\n",
    "    \n",
    "    def _find_neighbor(self, F, ind):\n",
    "        \"\"\"find a icosahedron neighbor of vertex i\"\"\"\n",
    "        FF = [F[i] for i in range(F.shape[0]) if ind in F[i]]\n",
    "        FF = np.concatenate(FF)\n",
    "        FF = np.unique(FF)\n",
    "        neigh = [f for f in FF if f != ind]\n",
    "        return neigh\n",
    "\n",
    "    def _rot_matrix(self, rot_axis, cos_t, sin_t):\n",
    "        k = rot_axis / np.linalg.norm(rot_axis)\n",
    "        I = np.eye(3)\n",
    "\n",
    "        R = []\n",
    "        for i in range(3):\n",
    "            v = I[i]\n",
    "            vr = v*cos_t+np.cross(k, v)*sin_t+k*(k.dot(v))*(1-cos_t)\n",
    "            R.append(vr)\n",
    "        R = np.stack(R, axis=-1)\n",
    "        return R\n",
    "\n",
    "    def _ico_rot_matrix(self, ind):\n",
    "        \"\"\"\n",
    "        return rotation matrix to perform permutation corresponding to \n",
    "        moving a certain icosahedron node to the top\n",
    "        \"\"\"\n",
    "        v0_ = self.v0.copy()\n",
    "        f0_ = self.f0.copy()\n",
    "        V0 = v0_[ind]\n",
    "        Z0 = np.array([0, 0, 1])\n",
    "\n",
    "        # rotate the point to the top (+z)\n",
    "        k = np.cross(V0, Z0)\n",
    "        ct = np.dot(V0, Z0)\n",
    "        st = -np.linalg.norm(k)\n",
    "        R = self._rot_matrix(k, ct, st)\n",
    "        v0_ = v0_.dot(R)\n",
    "\n",
    "        # rotate a neighbor to align with (+y)\n",
    "        ni = self._find_neighbor(f0_, ind)[0]\n",
    "        vec = v0_[ni].copy()\n",
    "        vec[2] = 0\n",
    "        vec = vec/np.linalg.norm(vec)\n",
    "        y_ = np.eye(3)[1]\n",
    "\n",
    "        k = np.eye(3)[2]\n",
    "        crs = np.cross(vec, y_)\n",
    "        ct = np.dot(vec, y_)\n",
    "        st = -np.sign(crs[-1])*np.linalg.norm(crs)\n",
    "\n",
    "        R2 = self._rot_matrix(k, ct, st)\n",
    "        return R.dot(R2)\n",
    "\n",
    "    def _rotseq(self, V, acc=9):\n",
    "        \"\"\"sequence to move an original node on icosahedron to top\"\"\"\n",
    "        seq = []\n",
    "        for i in range(11):\n",
    "            Vr = V.dot(self._ico_rot_matrix(i))\n",
    "            # lexsort\n",
    "            s1 = np.lexsort(np.round(V.T, acc))\n",
    "            s2 = np.lexsort(np.round(Vr.T, acc))\n",
    "            s = s1[np.argsort(s2)]\n",
    "            seq.append(s)\n",
    "        return tuple(seq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "\n",
    "gIcosa = SphereIcosahedron(2, sampling='vertex')\n",
    "gIcosa.plot(ax=ax)\n",
    "ax.view_init(0, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.spy(gIcosa.W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(gIcosa.e[:30], 'o')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gIcosa.compute_laplacian(\"combinatorial\")\n",
    "#graphCyl.compute_fourier_basis(recompute=True)\n",
    "gIcosa.set_coordinates(gIcosa.U[:,1:4])\n",
    "gIcosa.plot(vertex_size=10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
